{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d222740b",
   "metadata": {},
   "source": [
    "# Customize Your Own Pruners\n",
    "\n",
    "Torch-pruning is a scalable tool that enables you to create your own pruners with customized importance criteria and pruning schemes. For instance, you can use torch-pruning to implement the [Slimming pruner](https://arxiv.org/abs/1708.06519), which utilizes the scaling parameters in Batch Normalization (BN) to identify and remove unimportant channels. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba962543",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import sys, os\n",
    "sys.path.append(os.path.abspath(\"../\"))\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision.models import resnet18\n",
    "import torch_pruning as tp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a9d8d32",
   "metadata": {},
   "source": [
    "### 1. Pruner Definition\n",
    "\n",
    "Slimming Pruner leverages the scaling factor in Batch Normalization (BN) layers to determine the importance score of different channels. This technique follows a \"training-pruning-fine-tuning\" paradigm, which involves sparse training of the original model. In Torch-Pruning, the base class ``tp.pruner.MetaPruner`` provides a convenient ``.regularize(model)`` method for sparse training. Our first task is to implement such an interface to enable efficient regularization of BN parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "20ccf46e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MySlimmingPruner(tp.pruner.MetaPruner):\n",
    "    def regularize(self, model, reg):\n",
    "        for m in model.modules():\n",
    "            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine==True:\n",
    "                m.weight.grad.data.add_(reg*torch.sign(m.weight.data)) # Lasso for sparsity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07dcfd75",
   "metadata": {},
   "source": [
    "### 2. Importance function\n",
    "Now, we need a new importance criterion for slimming, which compares the magnitude of the scaling parameter in BN. In this work, importance criterion is a callable function or object which accept a group ``tp.PruningGroup`` as inputs. ``tp.PruningGroup`` records all coupled layers as well as their pruning indices. We can scan the group to design our own importance function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9f61c5aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MySlimmingImportance(tp.importance.Importance):\n",
    "    def __call__(self, group, **kwargs):\n",
    "        #note that we have multiple BNs in a group, \n",
    "        # we store layer-wise scores in a list and then reduce them to get the final results\n",
    "        group_imp = [] # (num_bns, num_channels) \n",
    "        # 1. iterate the group to estimate importance\n",
    "        for dep, idxs in group:\n",
    "            layer = dep.target.module # get the target model\n",
    "            prune_fn = dep.handler    # get the pruning function of target model, unused in this example\n",
    "            if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and layer.affine:\n",
    "                local_imp = torch.abs(layer.weight.data)\n",
    "                group_imp.append(local_imp)\n",
    "        if len(group_imp)==0: return None # return None if the group contains no BN layer\n",
    "        # 2. reduce your group importance to a 1-D scroe vector. Here we use the average score across layers.\n",
    "        group_imp = torch.stack(group_imp, dim=0).mean(dim=0) \n",
    "        return group_imp # (num_channels, )\n",
    "\n",
    "# You can implement any importance functions, as long as it transforms a group to a 1-D score vector.\n",
    "class RandomImportance(tp.importance.Importance):\n",
    "    @torch.no_grad()\n",
    "    def __call__(self, group, **kwargs):\n",
    "        _, idxs = group[0]\n",
    "        return torch.rand(len(idxs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7e1e073",
   "metadata": {},
   "source": [
    "### 3. Pruning\n",
    "Now let's leverage the customized pruner to slim a resnet-18 models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "762220da",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet18(pretrained=True)\n",
    "example_inputs = torch.randn(1, 3, 224, 224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b68e113a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0. importance criterion \n",
    "imp = MySlimmingImportance()\n",
    "\n",
    "# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.\n",
    "ignored_layers = []\n",
    "for m in model.modules():\n",
    "    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:\n",
    "        ignored_layers.append(m) # DO NOT prune the final classifier!\n",
    "\n",
    "# 2. Pruner initialization\n",
    "iterative_steps = 5 # You can prune your model to the target pruning ratio iteratively.\n",
    "pruner = MySlimmingPruner(\n",
    "    model, \n",
    "    example_inputs, \n",
    "    global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.\n",
    "    importance=imp, # importance criterion for parameter selection\n",
    "    iterative_steps=iterative_steps, # the number of iterations to achieve target pruning ratio\n",
    "    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}\n",
    "    ignored_layers=ignored_layers,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6a4b5b2",
   "metadata": {},
   "source": [
    "Sparse training with ``pruner.regularize``. Rember to regularize the model before ``optimizer.step()``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9fb7344a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training\n",
    "for _ in range(100):\n",
    "    pass\n",
    "    # optimizer.zero_grad()\n",
    "    # ...\n",
    "    # loss.backward()\n",
    "    # pruner.regularize(model, reg=1e-5)\n",
    "    # optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc447ab6",
   "metadata": {},
   "source": [
    "Pruning and finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ab5ee265",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 57, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(57, 115, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(57, 115, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(115, 230, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(115, 230, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(230, 460, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(230, 460, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=460, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 1/5, Params: 11.69 M => 9.48 M\n",
      "  Iter 1/5, MACs: 1.82 G => 1.47 G\n",
      "================\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 51, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(51, 102, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(51, 102, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(102, 204, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(102, 204, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(204, 409, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(204, 409, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=409, out_features=1000, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1000])\n",
      "  Iter 2/5, Params: 11.69 M => 7.53 M\n",
      "  Iter 2/5, MACs: 1.82 G => 1.18 G\n",
      "================\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(44, 89, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(44, 89, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(89, 179, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(89, 179, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(179, 358, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(179, 358, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=358, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 3/5, Params: 11.69 M => 5.82 M\n",
      "  Iter 3/5, MACs: 1.82 G => 0.91 G\n",
      "================\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 38, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(38, 76, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(38, 76, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(76, 153, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(76, 153, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(153, 307, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(153, 307, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=307, out_features=1000, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1000])\n",
      "  Iter 4/5, Params: 11.69 M => 4.32 M\n",
      "  Iter 4/5, MACs: 1.82 G => 0.68 G\n",
      "================\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=256, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 5/5, Params: 11.69 M => 3.06 M\n",
      "  Iter 5/5, MACs: 1.82 G => 0.49 G\n",
      "================\n"
     ]
    }
   ],
   "source": [
    "base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)\n",
    "for i in range(iterative_steps):\n",
    "    pruner.step()\n",
    "\n",
    "    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)\n",
    "    print(model)\n",
    "    print(model(example_inputs).shape)\n",
    "    print(\n",
    "        \"  Iter %d/%d, Params: %.2f M => %.2f M\"\n",
    "        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)\n",
    "    )\n",
    "    print(\n",
    "        \"  Iter %d/%d, MACs: %.2f G => %.2f G\"\n",
    "        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)\n",
    "    )\n",
    "    print(\"=\"*16)\n",
    "    # finetune your model here\n",
    "    # finetune(model)\n",
    "    # ..."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
